# ----------------------------------------------------------------------------
# Copyright (c) 2013--, scikit-bio development team.
#
# Distributed under the terms of the Modified BSD License.
#
# The full license is in the file LICENSE.txt, distributed with this software.
# ----------------------------------------------------------------------------

from typing import Optional, Union, Protocol, runtime_checkable, Tuple, Any, overload

import pandas as pd
import numpy as np
from numpy.typing import ArrayLike as NPArrayLike

from skbio.table import Table
from skbio.util import get_package


# ------------------------------------------------
# ArrayLike (see: skbio.util._array)
# ------------------------------------------------


@runtime_checkable
class StdArray(Protocol):  # pragma: no cover
    r"""Any object compliant with the Python array API standard [1]_.

    Array types this protocol is meant to describe include numpy.ndarray,
    cupy.ndarray, torch.Tensor, jax.Array, dask.array.Array, and
    sparse.SparseArray.

    See Also
    --------
    ._array._get_array

    Notes
    -----
    This is a Protocol class (PEP 544) which defines the methods and properties
    that an object must implement to be considered compliant with the Python
    array API standard.

    The Python array API standard and array-api-compat do not provide official
    typing. A third-party library, array-api-typing
    (https://pypi.org/project/array-api-typing/), exists but is not widely
    adopted. Typing for the array API is therefore manually defined here based
    on the standard [1]_.

    Because the class is decorated with ``runtime_checkable``,
    ``isinstance(obj, StdArray)`` only checks that every member declared here
    is present on ``obj``; argument and return types are not validated.

    In addition to the members of the standard's array object, this protocol
    declares ``__len__``, ``reshape``, ``squeeze``, ``astype``, ``all``,
    ``any``, ``min``, ``max``, ``argmin``, ``argmax``, ``sum``, ``mean``,
    ``std``, and ``var`` as methods. The standard does not define these as
    array methods (most of them are namespace functions, e.g. ``xp.sum(x)``),
    so an array that implements only the standard's array object will not pass
    the ``isinstance`` check.

    References
    ----------
    .. [1] https://data-apis.org/array-api/latest/

    """

    # Entry point of the standard: returns the namespace (``xp``) that provides
    # the functions operating on this array.
    def __array_namespace__(self, api_version: Optional[str] = None) -> Any: ...

    # Attributes
    @property
    def dtype(self) -> Any: ...
    @property
    def device(self) -> Any: ...
    @property
    def T(self) -> "StdArray": ...
    @property
    def mT(self) -> "StdArray": ...
    @property
    def ndim(self) -> int: ...
    @property
    def shape(self) -> Tuple[int, ...]: ...
    @property
    def size(self) -> int: ...

    # Comparison operators. Python-scalar or array operands produce an
    # element-wise array; the ``Any`` -> ``bool`` fallback is presumably there
    # to stay compatible with ``object.__eq__``/``object.__ne__``.
    @overload
    def __eq__(self, other: Union[int, float, bool, "StdArray"]) -> "StdArray": ...
    @overload
    def __eq__(self, other: Any) -> bool: ...
    @overload
    def __ne__(self, other: Union[int, float, bool, "StdArray"]) -> "StdArray": ...
    @overload
    def __ne__(self, other: Any) -> bool: ...
    @overload
    def __lt__(self, other: Union[int, float, bool, "StdArray"]) -> "StdArray": ...
    @overload
    def __lt__(self, other: Any) -> bool: ...
    @overload
    def __le__(self, other: Union[int, float, bool, "StdArray"]) -> "StdArray": ...
    @overload
    def __le__(self, other: Any) -> bool: ...
    @overload
    def __gt__(self, other: Union[int, float, bool, "StdArray"]) -> "StdArray": ...
    @overload
    def __gt__(self, other: Any) -> bool: ...
    @overload
    def __ge__(self, other: Union[int, float, bool, "StdArray"]) -> "StdArray": ...
    @overload
    def __ge__(self, other: Any) -> bool: ...

    # Arithmetic operators (plain, reflected ``__r*__`` and in-place ``__i*__``)
    def __pos__(self) -> "StdArray": ...
    def __neg__(self) -> "StdArray": ...
    def __add__(self, other: Union[int, float, "StdArray"]) -> "StdArray": ...
    def __radd__(self, other: Union[int, float, "StdArray"]) -> "StdArray": ...
    def __iadd__(self, other: Union[int, float, "StdArray"]) -> "StdArray": ...
    def __sub__(self, other: Union[int, float, "StdArray"]) -> "StdArray": ...
    def __rsub__(self, other: Union[int, float, "StdArray"]) -> "StdArray": ...
    def __isub__(self, other: Union[int, float, "StdArray"]) -> "StdArray": ...
    def __mul__(self, other: Union[int, float, "StdArray"]) -> "StdArray": ...
    def __rmul__(self, other: Union[int, float, "StdArray"]) -> "StdArray": ...
    def __imul__(self, other: Union[int, float, "StdArray"]) -> "StdArray": ...
    def __truediv__(self, other: Union[int, float, "StdArray"]) -> "StdArray": ...
    def __rtruediv__(self, other: Union[int, float, "StdArray"]) -> "StdArray": ...
    def __itruediv__(self, other: Union[int, float, "StdArray"]) -> "StdArray": ...
    def __floordiv__(self, other: Union[int, float, "StdArray"]) -> "StdArray": ...
    def __rfloordiv__(self, other: Union[int, float, "StdArray"]) -> "StdArray": ...
    def __ifloordiv__(self, other: Union[int, float, "StdArray"]) -> "StdArray": ...
    def __mod__(self, other: Union[int, float, "StdArray"]) -> "StdArray": ...
    def __rmod__(self, other: Union[int, float, "StdArray"]) -> "StdArray": ...
    def __imod__(self, other: Union[int, float, "StdArray"]) -> "StdArray": ...
    def __pow__(self, other: Union[int, float, "StdArray"]) -> "StdArray": ...
    def __rpow__(self, other: Union[int, float, "StdArray"]) -> "StdArray": ...
    def __ipow__(self, other: Union[int, float, "StdArray"]) -> "StdArray": ...
    def __abs__(self) -> "StdArray": ...

    # Array operators
    def __matmul__(self, other: Union[int, float, "StdArray"]) -> "StdArray": ...
    def __rmatmul__(self, other: Union[int, float, "StdArray"]) -> "StdArray": ...
    def __imatmul__(self, other: Union[int, float, "StdArray"]) -> "StdArray": ...

    # Bitwise operators (shifts accept integer operands only)
    def __invert__(self) -> "StdArray": ...
    def __and__(self, other: Union[int, bool, "StdArray"]) -> "StdArray": ...
    def __rand__(self, other: Union[int, bool, "StdArray"]) -> "StdArray": ...
    def __iand__(self, other: Union[int, bool, "StdArray"]) -> "StdArray": ...
    def __or__(self, other: Union[int, bool, "StdArray"]) -> "StdArray": ...
    def __ror__(self, other: Union[int, bool, "StdArray"]) -> "StdArray": ...
    def __ior__(self, other: Union[int, bool, "StdArray"]) -> "StdArray": ...
    def __xor__(self, other: Union[int, bool, "StdArray"]) -> "StdArray": ...
    def __rxor__(self, other: Union[int, bool, "StdArray"]) -> "StdArray": ...
    def __ixor__(self, other: Union[int, bool, "StdArray"]) -> "StdArray": ...
    def __lshift__(self, other: Union[int, "StdArray"]) -> "StdArray": ...
    def __rlshift__(self, other: Union[int, "StdArray"]) -> "StdArray": ...
    def __ilshift__(self, other: Union[int, "StdArray"]) -> "StdArray": ...
    def __rshift__(self, other: Union[int, "StdArray"]) -> "StdArray": ...
    def __rrshift__(self, other: Union[int, "StdArray"]) -> "StdArray": ...
    def __irshift__(self, other: Union[int, "StdArray"]) -> "StdArray": ...

    # Indexing, slicing, and manipulation. ``__getitem__``/``__setitem__`` are
    # standard array-object methods; ``__len__`` and the methods that follow
    # are extensions beyond the standard's array object (see Notes).
    def __getitem__(self, key: Any) -> Any: ...
    def __setitem__(self, key: Any, value: Any) -> None: ...
    def __len__(self) -> int: ...
    def reshape(self, *shape: int, order: str = "C") -> "StdArray": ...
    def squeeze(self, axis: Union[int, Tuple[int, ...], None] = None) -> "StdArray": ...
    def astype(self, dtype: Any, copy: bool = True) -> "StdArray": ...
    def all(
        self, axis: Union[int, Tuple[int, ...], None] = None, keepdims: bool = False
    ) -> "StdArray": ...
    def any(
        self, axis: Union[int, Tuple[int, ...], None] = None, keepdims: bool = False
    ) -> "StdArray": ...

    # Statistical functions (also extensions beyond the standard's array
    # object; the standard exposes them as namespace functions, see Notes)
    def min(
        self, axis: Union[int, Tuple[int, ...], None] = None, keepdims: bool = False
    ) -> "StdArray": ...
    def max(
        self, axis: Union[int, Tuple[int, ...], None] = None, keepdims: bool = False
    ) -> "StdArray": ...
    def argmin(self, axis: Union[int, Tuple[int, ...], None] = None) -> "StdArray": ...
    def argmax(self, axis: Union[int, Tuple[int, ...], None] = None) -> "StdArray": ...
    def sum(
        self, axis: Union[int, Tuple[int, ...], None] = None, keepdims: bool = False
    ) -> "StdArray": ...
    def mean(
        self, axis: Union[int, Tuple[int, ...], None] = None, keepdims: bool = False
    ) -> "StdArray": ...
    def std(
        self, axis: Union[int, Tuple[int, ...], None] = None, keepdims: bool = False
    ) -> "StdArray": ...
    def var(
        self, axis: Union[int, Tuple[int, ...], None] = None, keepdims: bool = False
    ) -> "StdArray": ...


# Public array-like alias: anything accepted by NumPy's own ``ArrayLike``, plus
# any object satisfying the ``StdArray`` protocol.
ArrayLike = Union[
    NPArrayLike,
    StdArray,
]

# ------------------------------------------------
# TableLike (see: skbio.table._tabular)
# ------------------------------------------------

# Core table containers, available whenever this module imports at all.
TableLike = Union[pd.DataFrame, np.ndarray, Table]

# Optional backends widen ``TableLike`` only when their package can be loaded.
# The ``is not None`` tests rely on ``get_package(..., raise_error=False)``
# yielding ``None`` for an unavailable package. The names ``pl`` and ``adt``
# stay bound at module level either way.
if (pl := get_package("polars", raise_error=False)) is not None:  # pragma: no cover
    TableLike = Union[TableLike, pl.DataFrame]  # type: ignore[misc]

if (adt := get_package("anndata", raise_error=False)) is not None:  # pragma: no cover
    TableLike = Union[TableLike, adt.AnnData]  # type: ignore[misc]


# ------------------------------------------------
# SeedLike (see: skbio.util.get_rng)
# ------------------------------------------------

# Seed forms covered by this alias: a plain integer, or an already-built NumPy
# random generator (the ``Generator`` API or the legacy ``RandomState``).
SeedLike = Union[
    int,
    np.random.Generator,
    np.random.RandomState,
]
